Binary Tree Pruning [DFS]

Time: O(N); Space: O(H), where H is the height of the tree; Difficulty: medium

We are given the root node of a binary tree in which every node’s value is either 0 or 1. Return the same tree with every subtree (of the given tree) that does not contain a 1 removed. (Recall that the subtree of a node X is X plus every node that is a descendant of X.)

Example 1:

Input: root = {TreeNode} [1,None,0,None,None,0,1]

Output: {TreeNode} [1,None,0,None,1]

Explanation:

  • In the original problem’s diagram, the red nodes mark the subtrees that do not contain a 1; these are the ones that get removed.

  • In the original problem statement, the diagram on the right (not reproduced here) shows the resulting tree.

Example 2:

Input: root = {TreeNode} [1,0,1,0,0,0,1]

Output: {TreeNode} [1,None,1,None,1]

Example 3:

Input: root = {TreeNode} [1,1,0,1,1,0,1,0]

Output: {TreeNode} [1,1,0,1,1,None,1]

Notes:

  • The binary tree will have at most 100 nodes.

  • The value of each node will only be 0 or 1.

[1]:
class TreeNode:
    """A binary-tree node: a value plus optional left and right children."""

    def __init__(self, x):
        # Store the payload first, then start with no children; callers link
        # subtrees afterwards by assigning to .left / .right.
        self.val = x
        self.left, self.right = None, None

Auxiliary Tools

[2]:
from graphviz import Graph

class TreeTasks(object):
    """Notebook helpers for inspecting binary trees."""

    def visualize_tree(self, tree):
        """Render ``tree`` as an undirected graphviz ``Graph`` and display it.

        Left children are labelled ``.<val>`` and right children ``<val>.`` so
        the side of a lone child is visible in the drawing.

        :param tree: root node (objects with ``val``/``left``/``right``), or
            None for an empty tree -- e.g. when pruning removed everything.
        :return: the graphviz ``Graph`` that was displayed (empty for None).
        """
        dot = Graph()

        def add_children(node):
            # Pre-order walk, left before right, adding each child node and
            # the edge from its parent.
            node_id = str(id(node))
            if node.left:
                left_id = str(id(node.left))
                dot.node(name=left_id, label="." + str(node.left.val))
                dot.edge(node_id, left_id)
                add_children(node.left)
            if node.right:
                right_id = str(id(node.right))
                dot.node(name=right_id, label=str(node.right.val) + ".")
                dot.edge(node_id, right_id)
                add_children(node.right)

        if tree is not None:
            # Node ids use id() rather than str(): if the node class ever
            # defines a __repr__/__str__ showing only the value, equal-valued
            # nodes would otherwise collapse into a single graph node.
            dot.node(name=str(id(tree)), label=str(tree.val))
            add_children(tree)
        # NOTE(review): `display` is not imported in this file; presumably it
        # is the IPython/Jupyter builtin, so this only works inside a notebook.
        display(dot)
        return dot
[3]:
class Solution1(object):
    def pruneTree(self, root):
        """Remove every subtree that contains no node with value 1.

        Post-order DFS: both children are pruned first, so a node is dropped
        only when, after pruning, it has no children left and its own value
        is 0. The tree is modified in place.

        :type root: TreeNode
        :rtype: TreeNode -- the same root object, or None if nothing survives
        """
        if not root:
            return None

        # Prune both subtrees (left first), then re-attach the survivors.
        root.left, root.right = self.pruneTree(root.left), self.pruneTree(root.right)

        has_child = root.left or root.right
        return root if has_child or root.val != 0 else None
[4]:
s = Solution1()

# Example 1: [1, None, 0, None, None, 0, 1] -> [1, None, 0, None, 1].
# Only the all-zero leaf root.right.left should be pruned.
root = TreeNode(1)
root.right = TreeNode(0)
root.right.left = TreeNode(0)
root.right.right = TreeNode(1)
tree = s.pruneTree(root)
t = TreeTasks()
dot = t.visualize_tree(tree)
assert tree.val == 1
assert tree.left is None
assert tree.right.val == 0
assert tree.right.left is None
assert tree.right.right.val == 1
../../_images/topics_tree_0814_binary_tree_pruning_[O(N),O(H),med]_5_0.svg
[5]:
# Example 2: [1, 0, 1, 0, 0, 0, 1] -> [1, None, 1, None, 1].
# The whole left subtree and root.right.left contain no 1 and are pruned.
root = TreeNode(1)
root.left = TreeNode(0)
root.right = TreeNode(1)
root.left.left = TreeNode(0)
root.left.right = TreeNode(0)
root.right.left = TreeNode(0)
root.right.right = TreeNode(1)
tree = s.pruneTree(root)
t = TreeTasks()
dot = t.visualize_tree(tree)
assert tree.val == 1
assert tree.left is None
assert tree.right.val == 1
assert tree.right.left is None
assert tree.right.right.val == 1
../../_images/topics_tree_0814_binary_tree_pruning_[O(N),O(H),med]_6_0.svg
[6]:
# Example 3: [1, 1, 0, 1, 1, 0, 1, 0] -> [1, 1, 0, 1, 1, None, 1].
# Only the zero leaves root.left.left.left and root.right.left are pruned.
root = TreeNode(1)
root.left = TreeNode(1)
root.right = TreeNode(0)
root.left.left = TreeNode(1)
root.left.right = TreeNode(1)
root.right.left = TreeNode(0)
root.right.right = TreeNode(1)
root.left.left.left = TreeNode(0)
tree = s.pruneTree(root)
t = TreeTasks()
dot = t.visualize_tree(tree)
assert tree.val == 1
assert tree.left.val == 1
assert tree.right.val == 0
assert tree.left.left.val == 1
assert tree.left.left.left is None
assert tree.left.right.val == 1
assert tree.right.left is None
assert tree.right.right.val == 1
../../_images/topics_tree_0814_binary_tree_pruning_[O(N),O(H),med]_7_0.svg